Notebook prepared by Maxim Ziatdinov (March 2022)
In this notebook, we show how to learn which measurement channel has the best predictive capability during DKL-based active learning.
!pip install -q git+https://github.com/ziatdinovmax/gpax@v0.0.3
!pip install -U gdown
from warnings import filterwarnings
import numpy as np
import matplotlib.pyplot as plt
import gpax
import jax.numpy as jnp
from sklearn.model_selection import train_test_split
filterwarnings("ignore", category=UserWarning)
gpax.utils.enable_x64()
#@title Load helper functions 🛠️ { form-width: "20%" }
from typing import Tuple
def make_window(imgsrc: np.ndarray, window_size: int,
                xpos: int, ypos: int) -> np.ndarray:
    """
    Returns the portion of the image (imgsrc) inside a window of size
    window_size centered at the (xpos, ypos) position
    """
    imgsrc = imgsrc[int(xpos-window_size/2):int(xpos+window_size/2),
                    int(ypos-window_size/2):int(ypos+window_size/2)]
    return imgsrc
def create_training_set(imgdata: np.ndarray, target: np.ndarray,
                        window_size: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Creates arrays with grid positions, features (local subimages),
    and targets (corresponding target values) from hyperspectral data
    """
    feature_arr, target_arr = [], []
    pos = []
    s1, s2 = imgdata.shape[:-1]
    for i in range(s1):
        for j in range(s2):
            arr_loc = make_window(imgdata, window_size, i, j)
            if arr_loc.shape[:-1] != (window_size, window_size):
                continue
            feature_arr.append(arr_loc)
            target_arr.append(target[i, j])
            pos.append([i, j])
    return np.array(pos), np.array(feature_arr), np.array(target_arr)
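To see what these helpers return, here is a quick, optional sanity check on a tiny synthetic array (the array sizes and random data below are purely illustrative and are not part of the BEPS dataset):
# Optional sanity check on synthetic data (illustrative only)
rng = np.random.default_rng(0)
toy_img = rng.random((20, 20, 3))      # a 20x20 "image" with 3 channels
toy_target = rng.random((20, 20))      # one scalar target per pixel
pos, feats, targs = create_training_set(toy_img, toy_target, window_size=4)
# Expected shapes: (N, 2), (N, 4, 4, 3), (N,), where N is the number of
# pixels whose 4x4 window fits entirely inside the image
print(pos.shape, feats.shape, targs.shape)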
!gdown https://drive.google.com/uc?id=1VdvoTJ1OaDo-qEIdgZJ1030aK5mBFcQh
Downloading... From: https://drive.google.com/uc?id=1VdvoTJ1OaDo-qEIdgZJ1030aK5mBFcQh To: /notebooks/PTO_BEPS_3um.npz 100%|██████████████████████████████████████| 62.1M/62.1M [00:02<00:00, 23.3MB/s]
BEPS = np.load("PTO_BEPS_3um.npz")
BEPS.files
['Image_Amplitude', 'Image_Phase', 'Image_Frequency', 'Image_Q_Factor', 'Image_Off_Field_Loop_Area', 'Image_On_Field_Loop_Area', 'Image_Loop_Width', 'Image_Average_Coercive_Field', 'Image_Negative_Coercive_Field', 'Image_Positive_Coercive_Field', 'Spectra_DC_Voltage', 'Spectra_Off_Field_Loop', 'Spectra_On_Field_Loop', 'Spectra_On_Field_Frequency_Loop', 'Spectra_Off_Field_Frequency_Loop']
We are going to work with the amplitude, phase, frequency, and Q-factor channels:
# DKL inputs
amplitude = BEPS['Image_Amplitude'].astype(np.float64)
phase = BEPS["Image_Phase"].astype(np.float64)
frequency = BEPS["Image_Frequency"].astype(np.float64)
qfactor = BEPS["Image_Q_Factor"].astype(np.float64)
# DKL targets
loop_area = BEPS['Image_Off_Field_Loop_Area']
norm_ = lambda x: (x - x.min()) / x.ptp()
# Concatenate input channels
all_channels = np.concatenate(
    [
        norm_(amplitude)[..., None],
        norm_(phase)[..., None],
        norm_(frequency)[..., None],
        norm_(qfactor)[..., None]
    ],
    axis=-1)
print(all_channels.shape, loop_area.shape)
# Plot
_, axes = plt.subplots(1, 4, figsize=(20, 5))
for i, ax in enumerate(axes):
    ax.imshow(all_channels[..., i], origin='lower')
    ax.axis('off')
(100, 100, 4) (100, 100)
Create training data (patches with amplitude/phase/frequency/Q-factor data as inputs and loop areas as targets):
indices_all, features_all, targets_all = create_training_set(all_channels, loop_area, 12)
#targets_all -= targets_all.mean()
#targets_all /= targets_all.std()
targets_all -= targets_all.min()
targets_all /= targets_all.ptp()
k = 450
_, axes = plt.subplots(1, 4, figsize=(20, 5))
for i, ax in enumerate(axes):
    ax.imshow(features_all[k, ..., i], origin='lower')
Define helper functions for running Bayesian inference and active channel learning:
def step_mcmc(X, y, X_new):
    """Trains a DKL model on the currently 'measured' data and returns the
    predictive mean and variance over the unmeasured points"""
    key1, key2 = gpax.utils.get_keys()
    data_dim = X.shape[-1]
    dkl = gpax.DKL(data_dim, z_dim=2, kernel='RBF')
    dkl.fit(key1, X, y, num_warmup=333, num_samples=333, num_chains=3, chain_method='vectorized')
    mean, samples = dkl.predict_in_batches(key2, X_new)
    var = samples.squeeze().var(0)
    return mean, var


def get_best_channel(record):
    """Returns the index of the channel with the highest average reward"""
    return record[:, 1].argmax()


def update_record(record, action, r):
    """Updates the visit count and running average reward of the selected channel"""
    new_r = (record[action, 0] * record[action, 1] + r) / (record[action, 0] + 1)
    record[action, 0] += 1
    record[action, 1] = new_r
    return record


def get_reward(obj_history, obj):
    """A reward of +/-1 is given if the integral uncertainty at the current step
    is smaller/larger than the integral uncertainty at the previous step"""
    if jnp.nansum(obj) < obj_history[-1]:
        r = 1
    else:
        r = -1
    return r


def update_datapoints(next_point_idx, train, test):
    """Moves the selected point from the 'unmeasured' set to the training set"""
    X_train, y_train, indices_train = train
    X_test, y_test, indices_test = test
    X_train = jnp.append(X_train, X_test[:, next_point_idx:next_point_idx+1], axis=1)
    X_test = jnp.delete(X_test, next_point_idx, axis=1)
    y_train = jnp.append(y_train, y_test[:, next_point_idx:next_point_idx+1], axis=-1)
    y_test = jnp.delete(y_test, next_point_idx, axis=-1)
    indices_train = jnp.append(indices_train, indices_test[next_point_idx:next_point_idx+1], axis=0)
    indices_test = jnp.delete(indices_test, next_point_idx, axis=0)
    return (X_train, y_train, indices_train), (X_test, y_test, indices_test)
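Before running the full loop, here is a small toy illustration (no MCMC involved) of how the reward and record bookkeeping works together; the numbers below are made up purely to show the mechanics:
# Toy illustration of the bandit bookkeeping (numbers are made up)
toy_record = np.zeros((4, 2))              # per channel: [visit count, average reward]
toy_var_history = [100.0]                  # pretend integral uncertainty at the previous step
fake_var = jnp.ones(10) * 9.0              # pretend predictive variance at the current step
r = get_reward(toy_var_history, fake_var)  # +1, since 90 < 100
toy_record = update_record(toy_record, 2, r)
toy_var_history.append(jnp.nansum(fake_var).item())
print(toy_record)                          # channel 3 (index 2) now has count 1 and reward 1
print(get_best_channel(toy_record))        # -> 2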
Create the training and 'unmeasured' datasets. We start with only about 1% of the grid points 'measured' (test_size=0.99).
n, h, w, c = features_all.shape
features_all = features_all.reshape(n, h*w, c)
(X_train, X_unmeasured, y_train, y_unmeasured,
 indices_train, indices_unmeasured) = train_test_split(
    features_all, targets_all, indices_all, test_size=0.99, random_state=1)
X_train = X_train.transpose(2, 0, 1)
X_unmeasured = X_unmeasured.transpose(2, 0, 1)
y_train = y_train[None].repeat(X_train.shape[0], axis=0)
y_unmeasured = y_unmeasured[None].repeat(X_train.shape[0], axis=0)
X_train.shape, X_unmeasured.shape, y_train.shape, y_unmeasured.shape
((4, 79, 144), (4, 7842, 144), (4, 79), (4, 7842))
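As a quick check, the fraction of 'measured' points follows directly from the shapes above:
n_measured, n_unmeasured = X_train.shape[1], X_unmeasured.shape[1]
print(n_measured / (n_measured + n_unmeasured))  # ~0.01, i.e. about 1% of the grid points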
Run active channel learning:
warmup_steps = 5
exploration_steps = 30
# epsilon in epsilon-greedy policy
#eps = np.ones(exploration_steps) * 0.4 # constant
eps = np.linspace(0.4, 0.1, exploration_steps) # annealing
np.random.seed(55) # rng seed for epsilon-greedy sampling
record = np.zeros((X_train.shape[0], 2))
channel_choices = []
var_history = []
def plot_acq():
    """Plots the acquisition function over the unmeasured points and marks
    the next point to 'measure'"""
    p = indices_unmeasured[next_point_idx]
    plt.figure(figsize=(6, 6))
    plt.scatter(indices_unmeasured[:, 1], indices_unmeasured[:, 0], c=obj, s=10, marker='s')
    plt.scatter(p[1], p[0], marker='x', s=64, c='red')
    plt.show()
# Warm-up phase
for w in range(warmup_steps):
    print("Warmup step {}/{}".format(w+1, warmup_steps))
    # Get/update DKL posterior
    mean, var = step_mcmc(X_train, y_train, X_unmeasured)
    record[:, 0] += 1
    # Select channel that resulted in lowest predictive uncertainty
    var_reduced = np.nansum(var, axis=-1)
    idx = np.argmin(var_reduced)
    # Update records
    var_history.append(var_reduced[idx].item())
    channel_choices.append(idx)
    record[idx, 1] += 1
    # Compute the acquisition function for the selected channel
    #obj = mean[idx] + 2*var[idx]
    obj = var[idx]
    # Get the next measurement point
    next_point_idx = obj.argmax()
    plot_acq()
    # Do "measurement"
    measured_point = y_unmeasured[:, next_point_idx]
    # Update datapoints
    (X_train, y_train, indices_train), (X_unmeasured, y_unmeasured, indices_unmeasured) = update_datapoints(
        next_point_idx, (X_train, y_train, indices_train), (X_unmeasured, y_unmeasured, indices_unmeasured)
    )

# Average over the number of warmup steps
record[:, 1] = record[:, 1] / warmup_steps
# Exploration-exploitation phase
for e in range(exploration_steps):
    print("\nExploration step {}/{}".format(e+1, exploration_steps))
    # Choose channel according to epsilon-greedy policy
    eps_i = eps[e] if e < exploration_steps else eps[-1]
    if np.random.random() > eps_i:
        idx = get_best_channel(record)
    else:
        idx = np.random.randint(record.shape[0])
    channel_choices.append(idx)
    print("Using channel {}".format(idx+1))
    # Update DKL posterior using the sampled channel
    mean, var = step_mcmc(X_train[idx], y_train[idx], X_unmeasured[idx])
    # Get reward
    r = get_reward(var_history, var)
    # Update records
    record = update_record(record, idx, r)
    var_history.append(jnp.nansum(var).item())
    # Compute acquisition function and do "measurement" at the suggested point
    #obj = mean + 2*var
    obj = var
    next_point_idx = obj.argmax()
    plot_acq()
    measured_point = y_unmeasured[:, next_point_idx]
    # Update datapoints
    (X_train, y_train, indices_train), (X_unmeasured, y_unmeasured, indices_unmeasured) = update_datapoints(
        next_point_idx, (X_train, y_train, indices_train), (X_unmeasured, y_unmeasured, indices_unmeasured)
    )
    # Print the running rewards
    print("\nRewards:")
    for i, r in enumerate(record):
        print("channel {}: counts {} reward (avg) {}".format(i+1, int(r[0]), np.round(r[1], 3)))
    # Store current points
    np.savez("hmcDKL_channels_polar_norm01.npz",
             X_train=X_train, y_train=y_train, X_unmeasured=X_unmeasured,
             y_unmeasured=y_unmeasured, indices_unmeasured=indices_unmeasured,
             indices_measured=indices_train, record=record, channels_id=channel_choices,
             var_history=var_history)
Warmup step 1/5
sample: 100%|██████████| 666/666 [20:36<00:00, 1.86s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 7.65 6.66 5.20 1.12 15.83 124.52 1.01
k_length[0,1] 7.41 6.20 5.29 1.09 15.20 143.70 1.01
k_length[1,0] 6.46 6.00 4.37 1.30 13.37 177.93 1.01
k_length[1,1] 6.37 5.60 4.46 0.93 13.08 179.68 1.00
k_length[2,0] 6.01 5.38 4.41 0.57 12.00 146.79 1.01
k_length[2,1] 4.24 4.60 2.57 0.49 9.67 133.14 1.05
k_length[3,0] 5.12 5.29 3.32 0.60 10.30 339.00 1.01
k_length[3,1] 4.71 4.36 3.29 0.70 9.87 299.90 1.00
k_scale[0] 0.21 0.21 0.15 0.04 0.39 545.33 1.00
k_scale[1] 0.25 0.19 0.18 0.05 0.49 431.77 1.00
k_scale[2] 0.28 0.26 0.20 0.04 0.55 555.78 1.00
k_scale[3] 0.24 0.27 0.18 0.04 0.48 621.41 1.01
noise[0] 0.01 0.01 0.01 0.00 0.02 200.87 1.00
noise[1] 0.01 0.00 0.01 0.00 0.01 238.37 1.01
noise[2] 0.02 0.02 0.01 0.00 0.06 18.80 1.16
noise[3] 0.02 0.02 0.02 0.00 0.05 68.18 1.01
Warmup step 2/5
sample: 100%|██████████| 666/666 [20:16<00:00, 1.83s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 7.01 6.38 4.81 1.06 14.28 68.24 1.04
k_length[0,1] 8.36 7.45 5.94 1.59 16.67 104.35 1.02
k_length[1,0] 5.87 5.54 4.10 1.10 11.48 142.60 1.03
k_length[1,1] 6.28 5.19 4.52 0.95 12.46 146.18 1.02
k_length[2,0] 5.76 6.52 3.78 0.61 12.56 212.28 1.01
k_length[2,1] 5.03 4.75 3.23 0.45 11.11 116.87 1.01
k_length[3,0] 4.73 4.89 2.82 0.69 10.69 80.24 1.02
k_length[3,1] 5.36 5.96 3.39 0.29 12.41 87.49 1.02
k_scale[0] 0.21 0.21 0.16 0.04 0.37 494.01 1.00
k_scale[1] 0.24 0.21 0.18 0.04 0.46 577.60 1.00
k_scale[2] 0.31 0.26 0.23 0.03 0.67 302.70 1.01
k_scale[3] 0.23 0.23 0.16 0.04 0.43 366.60 1.01
noise[0] 0.01 0.01 0.01 0.00 0.02 210.32 1.00
noise[1] 0.01 0.00 0.01 0.00 0.01 234.36 1.01
noise[2] 0.02 0.02 0.02 0.00 0.05 18.77 1.14
noise[3] 0.02 0.01 0.01 0.00 0.04 53.75 1.03
Warmup step 3/5
sample: 100%|██████████| 666/666 [20:59<00:00, 1.89s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 7.51 6.94 5.05 0.98 16.15 66.18 1.04
k_length[0,1] 8.43 8.30 5.68 0.97 17.50 95.33 1.03
k_length[1,0] 6.79 6.49 4.62 1.29 13.79 114.86 1.00
k_length[1,1] 6.29 5.08 4.74 1.03 11.93 129.11 1.01
k_length[2,0] 4.94 5.01 3.03 0.26 11.23 81.28 1.01
k_length[2,1] 5.89 6.09 3.81 0.41 12.46 71.91 1.02
k_length[3,0] 5.42 5.26 3.52 0.53 11.96 178.48 1.00
k_length[3,1] 4.88 5.25 2.91 0.51 11.64 89.41 1.01
k_scale[0] 0.20 0.18 0.15 0.04 0.39 360.75 1.00
k_scale[1] 0.26 0.24 0.19 0.03 0.49 498.69 1.00
k_scale[2] 0.30 0.28 0.21 0.03 0.61 220.40 1.02
k_scale[3] 0.26 0.28 0.19 0.04 0.47 370.66 1.01
noise[0] 0.01 0.01 0.01 0.00 0.02 127.84 1.01
noise[1] 0.01 0.00 0.01 0.00 0.01 206.99 1.00
noise[2] 0.02 0.02 0.01 0.00 0.05 41.09 1.07
noise[3] 0.02 0.01 0.01 0.00 0.03 18.85 1.17
Warmup step 4/5
sample: 100%|██████████| 666/666 [21:16<00:00, 1.92s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 7.64 6.64 5.43 1.27 15.19 119.77 1.02
k_length[0,1] 7.60 7.29 5.01 1.45 15.10 83.57 1.04
k_length[1,0] 6.54 6.73 4.43 1.10 13.10 108.27 1.03
k_length[1,1] 7.33 6.63 4.67 0.99 15.80 115.38 1.03
k_length[2,0] 5.64 6.36 3.72 0.50 11.63 177.36 1.01
k_length[2,1] 5.06 5.22 3.30 0.46 10.90 138.11 1.01
k_length[3,0] 5.30 6.02 3.25 0.42 11.71 150.47 1.02
k_length[3,1] 4.66 4.82 2.85 0.53 10.56 140.15 1.02
k_scale[0] 0.21 0.28 0.15 0.05 0.38 512.29 1.00
k_scale[1] 0.24 0.20 0.18 0.04 0.49 522.89 1.01
k_scale[2] 0.31 0.30 0.23 0.05 0.60 366.96 1.00
k_scale[3] 0.23 0.20 0.17 0.04 0.44 479.28 1.01
noise[0] 0.01 0.00 0.01 0.00 0.02 303.50 1.00
noise[1] 0.01 0.00 0.01 0.00 0.01 215.85 1.01
noise[2] 0.02 0.02 0.01 0.00 0.05 45.23 1.02
noise[3] 0.02 0.01 0.01 0.00 0.04 47.56 1.06
Warmup step 5/5
sample: 100%|██████████| 666/666 [21:04<00:00, 1.90s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 7.94 8.40 4.89 1.43 17.57 29.63 1.13
k_length[0,1] 9.07 8.39 6.28 1.71 17.89 25.83 1.13
k_length[1,0] 6.92 8.89 4.45 0.73 12.90 99.06 1.04
k_length[1,1] 7.12 6.59 5.02 1.37 14.36 103.18 1.01
k_length[2,0] 6.64 6.63 4.65 0.86 15.24 43.60 1.06
k_length[2,1] 4.00 4.49 2.16 0.38 9.66 9.02 1.14
k_length[3,0] 4.97 5.48 3.06 0.59 10.91 63.48 1.07
k_length[3,1] 4.78 5.10 2.60 0.37 11.37 60.72 1.04
k_scale[0] 0.21 0.20 0.15 0.04 0.37 289.68 1.01
k_scale[1] 0.24 0.21 0.18 0.04 0.46 451.61 1.00
k_scale[2] 0.29 0.24 0.21 0.04 0.59 255.03 1.01
k_scale[3] 0.22 0.17 0.17 0.04 0.43 340.23 1.01
noise[0] 0.01 0.00 0.01 0.00 0.01 155.95 1.04
noise[1] 0.01 0.00 0.01 0.00 0.01 89.77 1.03
noise[2] 0.02 0.02 0.01 0.00 0.05 15.82 1.21
noise[3] 0.01 0.01 0.01 0.00 0.02 57.89 1.04
Exploration step 1/30 Using channel 4
sample: 100%|██████████| 666/666 [16:39<00:00, 1.50s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 5.09 6.49 2.60 0.39 11.89 94.58 1.05
k_length[0,1] 6.06 7.00 3.61 0.58 14.23 166.89 1.03
k_scale[0] 0.23 0.22 0.17 0.05 0.41 607.87 1.01
noise[0] 0.01 0.01 0.01 0.00 0.02 27.08 1.12
Rewards: channel 1: counts 5 reward (avg) 0.0 channel 2: counts 5 reward (avg) 1.0 channel 3: counts 5 reward (avg) 0.0 channel 4: counts 6 reward (avg) -0.167 Exploration step 2/30 Using channel 2
sample: 100%|██████████| 666/666 [16:20<00:00, 1.47s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 8.00 8.18 5.14 1.10 16.85 132.21 1.03
k_length[0,1] 6.50 6.28 4.43 1.16 13.04 265.19 1.01
k_scale[0] 0.24 0.22 0.18 0.04 0.45 656.77 1.00
noise[0] 0.01 0.00 0.01 0.00 0.01 487.86 1.00
Rewards: channel 1: counts 5 reward (avg) 0.0 channel 2: counts 6 reward (avg) 1.0 channel 3: counts 5 reward (avg) 0.0 channel 4: counts 6 reward (avg) -0.167 Exploration step 3/30 Using channel 2
sample: 100%|██████████| 666/666 [16:43<00:00, 1.51s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 6.87 6.44 4.74 1.32 14.21 157.10 1.03
k_length[0,1] 7.18 7.28 4.87 1.12 14.27 174.03 1.03
k_scale[0] 0.24 0.20 0.18 0.04 0.46 510.33 1.00
noise[0] 0.01 0.00 0.01 0.00 0.01 515.93 1.00
Rewards: channel 1: counts 5 reward (avg) 0.0 channel 2: counts 7 reward (avg) 0.714 channel 3: counts 5 reward (avg) 0.0 channel 4: counts 6 reward (avg) -0.167 Exploration step 4/30 Using channel 2
sample: 100%|██████████| 666/666 [16:47<00:00, 1.51s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 6.87 6.28 4.72 1.02 13.65 156.98 1.01
k_length[0,1] 7.48 8.24 5.03 0.93 15.12 251.41 1.01
k_scale[0] 0.25 0.24 0.18 0.04 0.49 602.29 1.00
noise[0] 0.01 0.00 0.01 0.00 0.01 587.93 1.01
Rewards: channel 1: counts 5 reward (avg) 0.0 channel 2: counts 8 reward (avg) 0.75 channel 3: counts 5 reward (avg) 0.0 channel 4: counts 6 reward (avg) -0.167 Exploration step 5/30 Using channel 3
sample: 100%|██████████| 666/666 [16:32<00:00, 1.49s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 5.83 5.91 3.52 0.41 13.17 22.27 1.13
k_length[0,1] 4.64 4.24 3.31 0.64 10.83 18.68 1.17
k_scale[0] 0.25 0.19 0.19 0.06 0.49 31.78 1.09
noise[0] 0.01 0.01 0.01 0.00 0.02 36.34 1.06
Rewards: channel 1: counts 5 reward (avg) 0.0 channel 2: counts 8 reward (avg) 0.75 channel 3: counts 6 reward (avg) -0.167 channel 4: counts 6 reward (avg) -0.167 Exploration step 6/30 Using channel 3
sample: 100%|██████████| 666/666 [17:04<00:00, 1.54s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 7.00 7.44 4.83 0.37 16.27 59.99 1.07
k_length[0,1] 4.23 5.21 2.02 0.33 10.71 38.99 1.08
k_scale[0] 0.25 0.23 0.19 0.04 0.47 154.64 1.00
noise[0] 0.01 0.01 0.01 0.00 0.02 45.66 1.03
Rewards: channel 1: counts 5 reward (avg) 0.0 channel 2: counts 8 reward (avg) 0.75 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 6 reward (avg) -0.167 Exploration step 7/30 Using channel 2
sample: 100%|██████████| 666/666 [16:59<00:00, 1.53s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 6.58 6.41 4.46 1.04 13.01 174.91 1.01
k_length[0,1] 7.52 7.10 5.23 1.30 14.95 164.90 1.01
k_scale[0] 0.23 0.17 0.18 0.03 0.42 868.89 1.00
noise[0] 0.01 0.00 0.01 0.00 0.01 362.27 1.01
Rewards: channel 1: counts 5 reward (avg) 0.0 channel 2: counts 9 reward (avg) 0.778 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 6 reward (avg) -0.167 Exploration step 8/30 Using channel 1
sample: 100%|██████████| 666/666 [11:38<00:00, 1.05s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 10.20 9.45 7.29 1.33 19.75 30.15 1.06
k_length[0,1] 6.99 7.97 4.51 1.50 13.63 80.67 1.05
k_scale[0] 0.20 0.16 0.15 0.04 0.39 602.96 1.00
noise[0] 0.01 0.00 0.01 0.00 0.01 473.80 1.02
Rewards: channel 1: counts 6 reward (avg) -0.167 channel 2: counts 9 reward (avg) 0.778 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 6 reward (avg) -0.167 Exploration step 9/30 Using channel 2
sample: 100%|██████████| 666/666 [17:11<00:00, 1.55s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 7.03 6.47 4.72 0.76 15.01 179.58 1.02
k_length[0,1] 6.90 6.12 4.67 1.23 14.56 152.58 1.03
k_scale[0] 0.23 0.20 0.18 0.04 0.46 661.90 1.00
noise[0] 0.01 0.00 0.01 0.00 0.01 451.50 1.00
Rewards: channel 1: counts 6 reward (avg) -0.167 channel 2: counts 10 reward (avg) 0.8 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 6 reward (avg) -0.167 Exploration step 10/30 Using channel 2
sample: 100%|██████████| 666/666 [17:04<00:00, 1.54s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 7.80 7.93 5.18 1.01 15.70 246.64 1.01
k_length[0,1] 6.88 6.63 4.56 1.26 14.66 177.99 1.01
k_scale[0] 0.23 0.19 0.18 0.04 0.45 830.05 1.00
noise[0] 0.01 0.00 0.01 0.00 0.01 521.82 1.00
Rewards: channel 1: counts 6 reward (avg) -0.167 channel 2: counts 11 reward (avg) 0.818 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 6 reward (avg) -0.167 Exploration step 11/30 Using channel 4
sample: 100%|██████████| 666/666 [17:32<00:00, 1.58s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 6.08 7.69 3.24 0.31 13.36 115.03 1.00
k_length[0,1] 5.29 5.88 2.71 0.45 12.56 73.31 1.03
k_scale[0] 0.21 0.18 0.16 0.04 0.41 471.00 1.00
noise[0] 0.01 0.01 0.01 0.00 0.02 72.70 1.04
Rewards: channel 1: counts 6 reward (avg) -0.167 channel 2: counts 11 reward (avg) 0.818 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 7 reward (avg) -0.286 Exploration step 12/30 Using channel 2
sample: 100%|██████████| 666/666 [17:08<00:00, 1.54s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 8.12 8.74 5.16 1.19 16.89 242.96 1.01
k_length[0,1] 7.25 7.78 4.98 1.06 13.95 228.58 1.01
k_scale[0] 0.24 0.22 0.19 0.04 0.45 648.77 1.00
noise[0] 0.01 0.00 0.01 0.00 0.01 586.95 1.00
Rewards: channel 1: counts 6 reward (avg) -0.167 channel 2: counts 12 reward (avg) 0.833 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 7 reward (avg) -0.286 Exploration step 13/30 Using channel 2
sample: 100%|██████████| 666/666 [16:36<00:00, 1.50s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 6.89 6.75 4.51 1.47 14.85 168.06 1.01
k_length[0,1] 8.47 8.36 5.49 1.40 17.67 203.52 1.00
k_scale[0] 0.24 0.21 0.19 0.04 0.45 742.90 1.00
noise[0] 0.01 0.00 0.01 0.00 0.01 569.46 1.00
Rewards: channel 1: counts 6 reward (avg) -0.167 channel 2: counts 13 reward (avg) 0.846 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 7 reward (avg) -0.286 Exploration step 14/30 Using channel 2
sample: 100%|██████████| 666/666 [17:37<00:00, 1.59s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 7.42 7.41 4.64 1.08 16.39 172.40 1.01
k_length[0,1] 7.70 7.34 5.00 1.15 16.42 141.45 1.02
k_scale[0] 0.23 0.19 0.17 0.04 0.42 764.37 1.00
noise[0] 0.01 0.00 0.00 0.00 0.01 436.02 1.00
Rewards: channel 1: counts 6 reward (avg) -0.167 channel 2: counts 14 reward (avg) 0.714 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 7 reward (avg) -0.286 Exploration step 15/30 Using channel 2
sample: 100%|██████████| 666/666 [18:13<00:00, 1.64s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 7.40 7.65 4.60 1.21 15.70 71.71 1.03
k_length[0,1] 8.12 8.10 5.43 1.11 17.30 94.73 1.02
k_scale[0] 0.23 0.19 0.18 0.04 0.44 580.67 1.00
noise[0] 0.01 0.00 0.00 0.00 0.01 481.00 1.00
Rewards: channel 1: counts 6 reward (avg) -0.167 channel 2: counts 15 reward (avg) 0.733 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 7 reward (avg) -0.286 Exploration step 16/30 Using channel 2
sample: 100%|██████████| 666/666 [17:46<00:00, 1.60s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 8.06 7.69 5.28 1.29 17.68 106.12 1.06
k_length[0,1] 7.73 8.49 4.86 1.28 17.15 156.16 1.04
k_scale[0] 0.23 0.19 0.17 0.04 0.43 398.20 1.00
noise[0] 0.01 0.00 0.00 0.00 0.01 327.33 1.01
Rewards: channel 1: counts 6 reward (avg) -0.167 channel 2: counts 16 reward (avg) 0.75 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 7 reward (avg) -0.286 Exploration step 17/30 Using channel 2
sample: 100%|██████████| 666/666 [17:41<00:00, 1.59s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 7.61 7.59 4.95 1.21 15.82 33.17 1.05
k_length[0,1] 7.48 7.70 4.88 1.42 16.58 60.75 1.05
k_scale[0] 0.24 0.19 0.19 0.04 0.46 255.13 1.01
noise[0] 0.01 0.00 0.00 0.00 0.01 473.29 1.00
Rewards: channel 1: counts 6 reward (avg) -0.167 channel 2: counts 17 reward (avg) 0.647 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 7 reward (avg) -0.286 Exploration step 18/30 Using channel 2
sample: 100%|██████████| 666/666 [17:50<00:00, 1.61s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 6.80 7.80 4.56 1.01 13.08 188.51 1.03
k_length[0,1] 8.99 7.30 6.62 1.22 17.77 146.26 1.02
k_scale[0] 0.24 0.21 0.18 0.05 0.46 621.95 1.01
noise[0] 0.00 0.00 0.00 0.00 0.01 318.98 1.01
Rewards: channel 1: counts 6 reward (avg) -0.167 channel 2: counts 18 reward (avg) 0.667 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 7 reward (avg) -0.286 Exploration step 19/30 Using channel 1
sample: 100%|██████████| 666/666 [17:10<00:00, 1.55s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 10.18 9.66 6.79 1.86 21.15 143.37 1.01
k_length[0,1] 9.05 8.69 5.85 1.57 19.14 111.90 1.02
k_scale[0] 0.21 0.19 0.15 0.03 0.41 372.34 1.01
noise[0] 0.01 0.00 0.00 0.00 0.01 414.61 1.00
Rewards: channel 1: counts 7 reward (avg) -0.286 channel 2: counts 18 reward (avg) 0.667 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 7 reward (avg) -0.286 Exploration step 20/30 Using channel 2
sample: 100%|██████████| 666/666 [17:53<00:00, 1.61s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 7.76 7.06 4.92 1.40 17.00 123.53 1.04
k_length[0,1] 7.90 7.20 5.14 1.43 16.86 99.18 1.05
k_scale[0] 0.25 0.21 0.18 0.04 0.47 568.94 1.01
noise[0] 0.00 0.00 0.00 0.00 0.01 222.07 1.02
Rewards: channel 1: counts 7 reward (avg) -0.286 channel 2: counts 19 reward (avg) 0.684 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 7 reward (avg) -0.286 Exploration step 21/30 Using channel 2
sample: 100%|██████████| 666/666 [17:24<00:00, 1.57s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 7.92 7.94 4.82 1.29 17.11 89.47 1.03
k_length[0,1] 8.02 7.05 5.24 1.55 17.26 107.35 1.03
k_scale[0] 0.23 0.17 0.18 0.05 0.41 834.79 1.00
noise[0] 0.00 0.00 0.00 0.00 0.01 350.07 1.01
Rewards: channel 1: counts 7 reward (avg) -0.286 channel 2: counts 20 reward (avg) 0.7 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 7 reward (avg) -0.286 Exploration step 22/30 Using channel 2
sample: 100%|██████████| 666/666 [18:15<00:00, 1.64s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 8.99 9.12 5.61 1.37 19.50 65.78 1.05
k_length[0,1] 7.21 7.21 4.71 1.48 15.59 72.67 1.05
k_scale[0] 0.22 0.20 0.17 0.05 0.40 416.69 1.01
noise[0] 0.00 0.00 0.00 0.00 0.01 290.90 1.01
Rewards: channel 1: counts 7 reward (avg) -0.286 channel 2: counts 21 reward (avg) 0.619 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 7 reward (avg) -0.286 Exploration step 23/30 Using channel 2
sample: 100%|██████████| 666/666 [18:29<00:00, 1.67s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 8.53 7.56 5.68 1.35 18.84 106.66 1.03
k_length[0,1] 7.37 7.32 4.74 1.25 15.39 66.47 1.06
k_scale[0] 0.23 0.20 0.17 0.04 0.43 469.97 1.01
noise[0] 0.00 0.00 0.00 0.00 0.01 337.31 1.00
Rewards: channel 1: counts 7 reward (avg) -0.286 channel 2: counts 22 reward (avg) 0.545 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 7 reward (avg) -0.286 Exploration step 24/30 Using channel 2
sample: 100%|██████████| 666/666 [18:19<00:00, 1.65s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 9.25 8.73 6.14 0.87 20.52 76.56 1.06
k_length[0,1] 7.08 6.71 4.54 1.13 15.27 68.22 1.05
k_scale[0] 0.23 0.20 0.17 0.03 0.44 595.81 1.00
noise[0] 0.00 0.00 0.00 0.00 0.01 309.22 1.01
Rewards: channel 1: counts 7 reward (avg) -0.286 channel 2: counts 23 reward (avg) 0.565 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 7 reward (avg) -0.286 Exploration step 25/30 Using channel 2
sample: 100%|██████████| 666/666 [18:14<00:00, 1.64s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 8.07 7.80 5.32 0.97 17.11 104.82 1.03
k_length[0,1] 7.97 7.82 5.15 1.49 16.99 95.44 1.04
k_scale[0] 0.23 0.20 0.18 0.04 0.44 572.08 1.00
noise[0] 0.00 0.00 0.00 0.00 0.01 294.96 1.02
Rewards: channel 1: counts 7 reward (avg) -0.286 channel 2: counts 24 reward (avg) 0.583 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 7 reward (avg) -0.286 Exploration step 26/30 Using channel 2
sample: 100%|██████████| 666/666 [18:33<00:00, 1.67s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 9.07 8.95 5.88 1.48 18.97 89.87 1.02
k_length[0,1] 7.56 7.95 4.49 1.16 16.33 99.48 1.02
k_scale[0] 0.23 0.18 0.18 0.04 0.42 402.16 1.02
noise[0] 0.00 0.00 0.00 0.00 0.01 203.89 1.01
Rewards: channel 1: counts 7 reward (avg) -0.286 channel 2: counts 25 reward (avg) 0.6 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 7 reward (avg) -0.286 Exploration step 27/30 Using channel 2
sample: 100%|██████████| 666/666 [18:35<00:00, 1.67s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 7.81 7.16 5.22 1.44 16.16 127.77 1.02
k_length[0,1] 7.95 8.01 4.92 1.42 16.80 103.17 1.02
k_scale[0] 0.23 0.19 0.18 0.04 0.43 935.50 1.00
noise[0] 0.00 0.00 0.00 0.00 0.01 307.89 1.01
Rewards: channel 1: counts 7 reward (avg) -0.286 channel 2: counts 26 reward (avg) 0.615 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 7 reward (avg) -0.286 Exploration step 28/30 Using channel 2
sample: 100%|██████████| 666/666 [18:41<00:00, 1.68s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 8.75 9.19 5.14 1.36 20.18 100.88 1.03
k_length[0,1] 7.75 7.37 5.22 1.53 16.05 82.76 1.04
k_scale[0] 0.23 0.19 0.17 0.03 0.43 603.94 1.00
noise[0] 0.00 0.00 0.00 0.00 0.01 270.25 1.02
Rewards: channel 1: counts 7 reward (avg) -0.286 channel 2: counts 27 reward (avg) 0.556 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 7 reward (avg) -0.286 Exploration step 29/30 Using channel 2
sample: 100%|██████████| 666/666 [17:55<00:00, 1.61s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 8.65 8.33 5.77 1.42 18.52 150.03 1.01
k_length[0,1] 7.43 7.28 4.84 1.33 16.02 83.56 1.05
k_scale[0] 0.26 0.21 0.20 0.04 0.48 481.47 1.00
noise[0] 0.00 0.00 0.00 0.00 0.01 206.47 1.01
Rewards: channel 1: counts 7 reward (avg) -0.286 channel 2: counts 28 reward (avg) 0.5 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 7 reward (avg) -0.286 Exploration step 30/30 Using channel 2
sample: 100%|██████████| 666/666 [18:49<00:00, 1.70s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 8.99 7.88 6.48 1.20 18.23 160.91 1.01
k_length[0,1] 6.79 6.17 4.69 1.31 14.54 62.63 1.05
k_scale[0] 0.26 0.24 0.19 0.05 0.51 226.06 1.01
noise[0] 0.00 0.00 0.00 0.00 0.01 233.25 1.00
Rewards: channel 1: counts 7 reward (avg) -0.286 channel 2: counts 29 reward (avg) 0.448 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 7 reward (avg) -0.286
dataset = np.load("hmcDKL_channels_polar_norm01.npz")
X_train=dataset["X_train"]
y_train=dataset["y_train"]
X_unmeasured=dataset["X_unmeasured"]
y_unmeasured=dataset["y_unmeasured"]
indices_unmeasured=dataset["indices_unmeasured"]
indices_train=dataset["indices_measured"]
record=dataset["record"]
channel_choices=dataset["channels_id"].tolist()
var_history=dataset["var_history"].tolist()
View the results:
for i, r in enumerate(record):
    print("channel {}: counts {} reward (avg) {}".format(i+1, int(r[0]), np.round(r[1], 3)))
channel 1: counts 7 reward (avg) -0.286 channel 2: counts 29 reward (avg) 0.448 channel 3: counts 7 reward (avg) -0.286 channel 4: counts 7 reward (avg) -0.286
plt.figure(figsize=(5, 3), dpi=100)
plt.plot(np.arange(1, exploration_steps+warmup_steps+1), var_history, c='k')
plt.scatter(np.arange(1, exploration_steps+warmup_steps+1), var_history,
            c=[m+1 for m in channel_choices],
            s=128, cmap='rainbow', alpha=1)
plt.xlabel("Exploration step", fontsize=14)
plt.ylabel("Uncertainty", fontsize=14)
plt.xticks(size=14)
plt.yticks(size=14)
cbar = plt.colorbar(label="Channel Number")
cbar.set_ticks([1, 2, 3, 4])
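As an additional summary, one can count how often each channel was selected over the warm-up and exploration steps (a small optional check):
counts = np.bincount(np.asarray(channel_choices), minlength=record.shape[0])
for i, c in enumerate(counts):
    print("channel {}: selected {} times".format(i + 1, c))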
Run DKL for all the channels and plot the predictions:
mean, var = step_mcmc(X_train, y_train, X_unmeasured)
sample: 100%|██████████| 666/666 [29:42<00:00, 2.68s/it]
mean std median 5.0% 95.0% n_eff r_hat
k_length[0,0] 6.04 4.32 4.79 2.05 10.29 40.34 1.06
k_length[0,1] 11.93 8.18 9.88 2.19 22.15 51.32 1.07
k_length[1,0] 5.70 5.26 4.12 1.54 10.31 53.43 1.04
k_length[1,1] 10.51 8.96 8.53 1.77 20.10 108.32 1.01
k_length[2,0] 8.44 11.94 5.60 0.83 16.22 51.53 1.08
k_length[2,1] 3.80 3.64 2.69 0.77 7.44 54.60 1.08
k_length[3,0] 4.33 5.02 2.50 0.70 9.84 41.81 1.06
k_length[3,1] 8.29 7.40 6.75 0.86 15.97 41.88 1.05
k_scale[0] 0.21 0.18 0.16 0.04 0.38 176.10 1.01
k_scale[1] 0.23 0.19 0.18 0.05 0.45 332.78 1.02
k_scale[2] 0.38 0.29 0.28 0.06 0.77 129.45 1.02
k_scale[3] 0.30 0.27 0.23 0.05 0.58 106.15 1.05
noise[0] 0.01 0.00 0.01 0.00 0.01 99.56 1.01
noise[1] 0.00 0.00 0.00 0.00 0.01 56.03 1.07
noise[2] 0.00 0.00 0.00 0.00 0.01 85.43 1.04
noise[3] 0.00 0.00 0.00 0.00 0.01 16.82 1.10
from mpl_toolkits.axes_grid1 import make_axes_locatable
for m, v in zip(mean, var):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 6))
    im1 = ax1.scatter(indices_unmeasured[:, 1], indices_unmeasured[:, 0], c=m, s=10, marker='s')
    ax1.set_aspect('equal')
    divider = make_axes_locatable(ax1)
    cax = divider.append_axes('right', size='3%', pad=0.1)
    fig.colorbar(im1, cax=cax, orientation='vertical')
    im2 = ax2.scatter(indices_unmeasured[:, 1], indices_unmeasured[:, 0], c=v, s=10, marker='s')
    ax2.set_aspect('equal')
    divider = make_axes_locatable(ax2)
    cax = divider.append_axes('right', size='3%', pad=0.1)
    fig.colorbar(im2, cax=cax, orientation='vertical')
    plt.show()
Total prediction error (summed absolute deviation over the unmeasured points) for each channel:
print(jnp.sum((abs(mean - y_unmeasured)), axis=-1))
[1184.67279246 1005.68638224 1411.77420842 1554.34109267]
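To see where the predictions deviate most, the per-point absolute error can also be scattered back onto the grid, in the same way the predictions and uncertainties were plotted above (a minimal sketch):
abs_err = np.abs(np.asarray(mean) - np.asarray(y_unmeasured))  # (n_channels, n_unmeasured_points)
fig, axes = plt.subplots(1, abs_err.shape[0], figsize=(20, 5))
for i, ax in enumerate(axes):
    im = ax.scatter(indices_unmeasured[:, 1], indices_unmeasured[:, 0],
                    c=abs_err[i], s=10, marker='s')
    ax.set_aspect('equal')
    ax.set_title("channel {}".format(i + 1))
    fig.colorbar(im, ax=ax, shrink=0.8)
plt.show()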